import argparse


def _str2bool(value):
    """Parse a command-line boolean flag value.

    Accepts (case-insensitively) ``true/t/yes/y/1`` and ``false/f/no/n/0``.
    Used instead of ``type=eval`` so that command-line text is never
    executed as Python code.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean,
            which argparse reports as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected, got %r." % value)


def parse_args(argv=None):
    """Parse command-line options for fine-tuning a BERT-style model.

    Most numeric/path options default to ``None`` so that a later step
    (e.g. ``set_default_args``) can fill in task-specific values.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    # The module docstring is the parser's description, not its program name.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train.")
    parser.add_argument(
        "--model_name_or_path",
        default='bert-base-uncased',
        type=str,
        help="Path to pre-trained bert model or shortcut name.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="The output directory where the checkpoints will be saved.")
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The directory where the dataset will be loaded from.")
    parser.add_argument(
        "--init_from_ckpt",
        default=None,
        type=str,
        help="The path of checkpoint to be loaded.")
    parser.add_argument(
        "--max_seq_len",
        default=None,
        type=int,
        help="The maximum total input sequence length after tokenization for training. "
        "Sequences longer than this will be truncated, sequences shorter will be padded."
    )
    parser.add_argument(
        "--test_max_seq_len",
        default=None,
        type=int,
        help="The maximum total input sequence length after tokenization for testing. "
        "Sequences longer than this will be truncated, sequences shorter will be padded."
    )
    parser.add_argument(
        "--batch_size",
        default=None,
        type=int,
        help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--test_batch_size",
        default=None,
        type=int,
        help="Batch size per GPU/CPU for testing.")
    parser.add_argument(
        "--learning_rate",
        default=None,
        type=float,
        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--weight_decay",
        default=0.01,
        type=float,
        help="Weight decay if we apply some.")
    parser.add_argument(
        "--epochs",
        default=None,
        type=int,
        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--logging_steps",
        default=None,
        type=int,
        help="Log every X update steps.")
    parser.add_argument(
        "--save_steps",
        default=None,
        type=int,
        help="Save checkpoint every X update steps.")
    parser.add_argument(
        "--seed", default=42, type=int, help="Random seed for initialization.")
    parser.add_argument(
        "--n_gpu",
        default=1,
        type=int,
        help="The number of gpus to use, 0 for cpu.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="The proportion of warmup.")
    parser.add_argument(
        '--max_grad_norm',
        default=1.0,
        type=float,
        help='The max value of grad norm.')
    # Boolean switches take an explicit value (e.g. ``--do_train False``);
    # _str2bool parses it safely instead of eval-ing it.
    parser.add_argument(
        "--do_train", default=True, type=_str2bool, help="Whether training.")
    parser.add_argument(
        "--do_eval", default=True, type=_str2bool, help="Whether evaluation.")
    parser.add_argument(
        "--do_test", default=True, type=_str2bool, help="Whether testing.")

    args = parser.parse_args(argv)
    return args


# Per-task fallback values; each is applied only if the user left the
# corresponding option unset on the command line.
_TASK_DEFAULTS = {
    'udc': {
        'save_steps': 1000,
        'logging_steps': 100,
        'epochs': 2,
        'max_seq_len': 210,
        'test_batch_size': 100,
    },
    'dstc2': {
        'save_steps': 400,
        'logging_steps': 20,
        'epochs': 40,
        'learning_rate': 5e-5,
        'max_seq_len': 256,
        'test_max_seq_len': 512,
    },
    'atis_slot': {'save_steps': 100, 'logging_steps': 10, 'epochs': 50},
    'atis_intent': {'save_steps': 100, 'logging_steps': 10, 'epochs': 20},
    'mrda': {'save_steps': 500, 'logging_steps': 200, 'epochs': 7},
    'swda': {'save_steps': 500, 'logging_steps': 200, 'epochs': 3},
}


def _is_unset(value):
    """Return True if an option value means "not provided" (None or '')."""
    # Explicit zeros (e.g. ``--epochs 0``) count as provided and are kept.
    return value is None or value == ''


def _fill_unset(args, name, value):
    """Set ``args.<name>`` to ``value`` only when it is currently unset."""
    if _is_unset(getattr(args, name)):
        setattr(args, name, value)


def set_default_args(args):
    """Fill unset options on ``args`` in place with task-aware defaults.

    ``args.task_name`` is lower-cased first. Task-specific defaults from
    ``_TASK_DEFAULTS`` are applied, then generic fallbacks. Values the user
    supplied, including explicit zeros, are never overwritten.

    Args:
        args: Namespace with the attributes produced by ``parse_args``.

    Raises:
        ValueError: if ``args.task_name`` is not a supported task.
    """
    args.task_name = args.task_name.lower()
    task_defaults = _TASK_DEFAULTS.get(args.task_name)
    if task_defaults is None:
        raise ValueError('Not support task: %s.' % args.task_name)
    for name, value in task_defaults.items():
        _fill_unset(args, name, value)

    # Generic fallbacks. Order matters: the test-time sizes fall back to the
    # (possibly just-defaulted) training sizes.
    _fill_unset(args, 'data_dir', './DGU_datasets/' + args.task_name)
    _fill_unset(args, 'output_dir', './checkpoints/' + args.task_name)
    _fill_unset(args, 'learning_rate', 2e-5)
    _fill_unset(args, 'batch_size', 32)
    _fill_unset(args, 'test_batch_size', args.batch_size)
    _fill_unset(args, 'max_seq_len', 128)
    _fill_unset(args, 'test_max_seq_len', args.max_seq_len)
